{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torch==2.2 in /opt/conda/lib/python3.10/site-packages (2.2.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.8.5)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2023.12.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2) (12.3.101)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch==2.2) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch==2.2) (1.3.0)\n",
      "Requirement already satisfied: torchvision==0.17 in /opt/conda/lib/python3.10/site-packages (0.17.0)\n",
      "Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17) (1.26.2)\n",
      "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17) (2.28.1)\n",
      "Requirement already satisfied: torch==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17) (2.2.0)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17) (10.2.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (2.8.5)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (2023.12.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2.0->torchvision==0.17) (12.3.101)\n",
      "Requirement already satisfied: charset-normalizer<3,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17) (2.1.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17) (3.3)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17) (1.26.10)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17) (2022.6.15)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch==2.2.0->torchvision==0.17) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch==2.2.0->torchvision==0.17) (1.3.0)\n",
      "Requirement already satisfied: matplotlib==3.5.2 in /opt/conda/lib/python3.10/site-packages (3.5.2)\n",
      "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (0.12.1)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (4.46.0)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (1.4.5)\n",
      "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (1.26.2)\n",
      "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (21.3)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (10.2.0)\n",
      "Requirement already satisfied: pyparsing>=2.2.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (3.0.9)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /opt/conda/lib/python3.10/site-packages (from matplotlib==3.5.2) (2.8.2)\n",
      "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib==3.5.2) (1.16.0)\n",
      "Requirement already satisfied: scikit-image==0.19.3 in /opt/conda/lib/python3.10/site-packages (0.19.3)\n",
      "Requirement already satisfied: numpy>=1.17.0 in /opt/conda/lib/python3.10/site-packages (from scikit-image==0.19.3) (1.26.2)\n",
      "Requirement already satisfied: scipy>=1.4.1 in /opt/conda/lib/python3.10/site-packages (from scikit-image==0.19.3) (1.11.4)\n",
      "Requirement already satisfied: networkx>=2.2 in /opt/conda/lib/python3.10/site-packages (from scikit-image==0.19.3) (2.8.5)\n",
      "Requirement already satisfied: pillow!=7.1.0,!=7.1.1,!=8.3.0,>=6.1.0 in /opt/conda/lib/python3.10/site-packages (from scikit-image==0.19.3) (10.2.0)\n",
      "Requirement already satisfied: imageio>=2.4.1 in /opt/conda/lib/python3.10/site-packages (from scikit-image==0.19.3) (2.33.0)\n",
      "Requirement already satisfied: tifffile>=2019.7.26 in /opt/conda/lib/python3.10/site-packages (from scikit-image==0.19.3) (2023.12.9)\n",
      "Requirement already satisfied: PyWavelets>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from scikit-image==0.19.3) (1.5.0)\n",
      "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from scikit-image==0.19.3) (21.3)\n",
      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging>=20.0->scikit-image==0.19.3) (3.0.9)\n"
     ]
    }
   ],
   "source": [
    "!pip install torch==2.2\n",
    "!pip install torchvision==0.17\n",
    "!pip install matplotlib==3.5.2\n",
    "!pip install scikit-image==0.19.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the U-Net based Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class UNetGenerator(nn.Module):\n",
    "    def __init__(self, chnls_in=3, chnls_op=3):\n",
    "        super(UNetGenerator, self).__init__()\n",
    "        self.down_conv_layer_1 = DownConvBlock(chnls_in, 64, norm=False)\n",
    "        self.down_conv_layer_2 = DownConvBlock(64, 128)\n",
    "        self.down_conv_layer_3 = DownConvBlock(128, 256)\n",
    "        self.down_conv_layer_4 = DownConvBlock(256, 512, dropout=0.5)\n",
    "        self.down_conv_layer_5 = DownConvBlock(512, 512, dropout=0.5)\n",
    "        self.down_conv_layer_6 = DownConvBlock(512, 512, dropout=0.5)\n",
    "        self.down_conv_layer_7 = DownConvBlock(512, 512, dropout=0.5)\n",
    "        self.down_conv_layer_8 = DownConvBlock(512, 512, norm=False, dropout=0.5)\n",
    "        self.up_conv_layer_1 = UpConvBlock(512, 512, dropout=0.5)\n",
    "        self.up_conv_layer_2 = UpConvBlock(1024, 512, dropout=0.5)\n",
    "        self.up_conv_layer_3 = UpConvBlock(1024, 512, dropout=0.5)\n",
    "        self.up_conv_layer_4 = UpConvBlock(1024, 512, dropout=0.5)\n",
    "        self.up_conv_layer_5 = UpConvBlock(1024, 256)\n",
    "        self.up_conv_layer_6 = UpConvBlock(512, 128)\n",
    "        self.up_conv_layer_7 = UpConvBlock(256, 64)\n",
    "        self.upsample_layer = nn.Upsample(scale_factor=2)\n",
    "        self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0))\n",
    "        self.conv_layer_1 = nn.Conv2d(128, chnls_op, 4, padding=1)\n",
    "        self.activation = nn.Tanh()\n",
    "    def forward(self, x):\n",
    "        enc1 = self.down_conv_layer_1(x)\n",
    "        enc2 = self.down_conv_layer_2(enc1)\n",
    "        enc3 = self.down_conv_layer_3(enc2)\n",
    "        enc4 = self.down_conv_layer_4(enc3)\n",
    "        enc5 = self.down_conv_layer_5(enc4)\n",
    "        enc6 = self.down_conv_layer_6(enc5)\n",
    "        enc7 = self.down_conv_layer_7(enc6)\n",
    "        enc8 = self.down_conv_layer_8(enc7)\n",
    "        dec1 = self.up_conv_layer_1(enc8, enc7)\n",
    "        dec2 = self.up_conv_layer_2(dec1, enc6)\n",
    "        dec3 = self.up_conv_layer_3(dec2, enc5)\n",
    "        dec4 = self.up_conv_layer_4(dec3, enc4)\n",
    "        dec5 = self.up_conv_layer_5(dec4, enc3)\n",
    "        dec6 = self.up_conv_layer_6(dec5, enc2)\n",
    "        dec7 = self.up_conv_layer_7(dec6, enc1)\n",
    "        final = self.upsample_layer(dec7)\n",
    "        final = self.zero_pad(final)\n",
    "        final = self.conv_layer_1(final)\n",
    "        return self.activation(final)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class UpConvBlock(nn.Module):\n",
    "    def __init__(self, ip_sz, op_sz, dropout=0.0):\n",
    "        super(UpConvBlock, self).__init__()\n",
    "        self.layers = [\n",
    "            nn.ConvTranspose2d(ip_sz, op_sz, 4, 2, 1),\n",
    "            nn.InstanceNorm2d(op_sz),\n",
    "            nn.ReLU(),\n",
    "        ]\n",
    "        if dropout:\n",
    "            self.layers += [nn.Dropout(dropout)]\n",
    "    def forward(self, x, enc_ip):\n",
    "        x = nn.Sequential(*(self.layers))(x)\n",
    "        op = torch.cat((x, enc_ip), 1)\n",
    "        return op"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DownConvBlock(nn.Module):\n",
    "    def __init__(self, ip_sz, op_sz, norm=True, dropout=0.0):\n",
    "        super(DownConvBlock, self).__init__()\n",
    "        self.layers = [nn.Conv2d(ip_sz, op_sz, 4, 2, 1)]\n",
    "        if norm:\n",
    "            self.layers.append(nn.InstanceNorm2d(op_sz))\n",
    "        self.layers += [nn.LeakyReLU(0.2)]\n",
    "        if dropout:\n",
    "            self.layers += [nn.Dropout(dropout)]\n",
    "    def forward(self, x):\n",
    "        op = nn.Sequential(*(self.layers))(x)\n",
    "        return op"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Pix2PixDiscriminator(nn.Module):\n",
    "    def __init__(self, chnls_in=3):\n",
    "        super(Pix2PixDiscriminator, self).__init__()\n",
    "        def disc_conv_block(chnls_in, chnls_op, norm=1):\n",
    "            layers = [nn.Conv2d(chnls_in, chnls_op, 4, stride=2, padding=1)]\n",
    "            if normalization:\n",
    "                layers.append(nn.InstanceNorm2d(chnls_op))\n",
    "            layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
    "            return layers\n",
    "        self.lyr1 = disc_conv_block(chnls_in * 2, 64, norm=0)\n",
    "        self.lyr2 = disc_conv_block(64, 128)\n",
    "        self.lyr3 = disc_conv_block(128, 256)\n",
    "        self.lyr4 = disc_conv_block(256, 512)\n",
    "    def forward(self, real_image, translated_image):\n",
    "        ip = torch.cat((real_image, translated_image), 1)\n",
    "        op = self.lyr1(ip)\n",
    "        op = self.lyr2(op)\n",
    "        op = self.lyr3(op)\n",
    "        op = self.lyr4(op)\n",
    "        op = nn.ZeroPad2d((1, 0, 1, 0))(op)\n",
    "        op = nn.Conv2d(512, 1, 4, padding=1)(op)\n",
    "        return op"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (Local)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
